from ensemble_boxes import *
import json, copy
from numba.core.types.scalars import EnumMember
from tqdm import tqdm





def wbf(josn_list, score, iou_thr=0.3, skip_box_thr=0.1, sigma=0.1,
        image_info_path='./image_info_val.json',
        output_path='./test_soft_nms.json'):
    """Fuse several per-model detection result files and write the result.

    Despite the name, the fusion step calls ``soft_nms``.
    A ``weighted_boxes_fusion`` call is kept commented out in the original
    code as an alternative.

    Args:
        josn_list: Paths to per-model detection JSON files, as read by
            ``get_bbox_ann`` (dicts with ``image_id``, ``category_id``,
            ``score`` and a pixel ``bbox`` of ``[x, y, w, h]``).
        score: One score per file, in the same order as ``josn_list``.
            It is turned into fusion weights by ``get_weights``.
        iou_thr: IoU threshold passed to ``soft_nms``.
        skip_box_thr: Passed to ``soft_nms`` as ``thresh``.
        sigma: Passed to ``soft_nms``.
        image_info_path: JSON file whose ``images`` list gives each image's
            ``id``, ``width`` and ``height``.
        output_path: File the fused detections are written to, as a JSON
            list in the same ``[x, y, w, h]`` pixel format as the inputs.

    Raises:
        ValueError: if ``score`` and ``josn_list`` differ in length.
    """
    if len(score) != len(josn_list):
        raise ValueError(
            "expected one score per result file, got %d scores for %d files"
            % (len(score), len(josn_list)))

    ori_weights = get_weights(score)
    imgid_info = get_image_info(image_info_path)

    result_list = [get_bbox_ann(json_path) for json_path in josn_list]
    print("start fusion....")

    # Fuse image by image.
    new_result = []
    for img_id in tqdm(imgid_info):
        W, H = imgid_info[img_id]
        boxes_list, scores_list, labels_list, weights = [], [], [], []
        # Only models that predicted on this image contribute. Each model's
        # weight is appended together with its boxes so the two stay aligned.
        # Deleting from a weight copy by model index would shift the later
        # indices and drop the wrong weight once a second model is missing.
        for res, weight in zip(result_list, ori_weights):
            ann = res.get(img_id)
            if ann is None:
                continue
            boxes_list.append(ann['boxes'])
            scores_list.append(ann['scores'])
            labels_list.append(ann['labels'])
            weights.append(weight)

        # Nothing to fuse for this image; skip it rather than handing
        # soft_nms empty lists.
        if not boxes_list:
            continue

        # boxes, scores, labels = weighted_boxes_fusion(boxes_list, scores_list,
        #                                                 labels_list, weights=weights,
        #                                                 iou_thr=iou_thr, skip_box_thr=skip_box_thr)
        boxes, scores, labels = soft_nms(
            boxes_list, scores_list, labels_list, weights=weights,
            iou_thr=iou_thr, sigma=sigma, thresh=skip_box_thr)

        # Undo the normalisation done in get_bbox_ann:
        # [x1, y1, x2, y2] scaled by W/H -> [x, y, w, h] in pixels.
        for box, box_score, label in zip(boxes, scores, labels):
            new_result.append({
                'image_id': img_id,
                'category_id': int(label),
                "score": float(box_score),
                "bbox": [float(box[0] * W), float(box[1] * H),
                         float((box[2] - box[0]) * W),
                         float((box[3] - box[1]) * H)],
            })

    # Report box counts for each input model versus the fused output.
    # len(res) would count images, not boxes.
    print("**************长度变化*****************")
    for i, res in enumerate(result_list):
        n_boxes = sum(len(ann['boxes']) for ann in res.values())
        print((" " if i == 0 else "+") + str(n_boxes))
    print('---------')
    print(" " + str(len(new_result)))
    with open(output_path, 'w') as json_file:
        json.dump(new_result, json_file)

    
    

def get_image_info(image_info_path):
    """Map each image id to its ``[width, height]``.

    Args:
        image_info_path: JSON file with a top-level ``images`` list. Each
            entry provides ``id``, ``width`` and ``height``.

    Returns:
        ``{image_id: [width, height]}``. If an id repeats, the last entry wins.
    """
    with open(image_info_path, 'r') as handle:
        images = json.load(handle)['images']
    return {img['id']: [img['width'], img['height']] for img in images}
    
def get_weights(score):
    """Scale every model score by the largest one.

    The top-scoring model therefore gets weight 1.0.
    Raises ``ValueError`` for an empty ``score`` (from ``max``).
    """
    top = max(score)
    return list(map(lambda s: s / top, score))

def get_bbox_ann(json_path, imgid_info=None):
    """Load one model's detections and group them by image id.

    Each pixel ``[x, y, w, h]`` box is converted to corner form
    ``[x / W, y / H, (x + w) / W, (y + h) / H]``. This is the normalised
    corner format used by ``ensemble_boxes``.

    Args:
        json_path: Detection JSON, a list of dicts with ``image_id``,
            ``bbox``, ``score`` and ``category_id``.
        imgid_info: Optional ``{image_id: [width, height]}`` mapping. If
            omitted, it is loaded from ``./image_info_val.json``, as before.
            Pass it explicitly to avoid re-reading that file for every model.

    Returns:
        ``{image_id: {"boxes": [...], "scores": [...], "labels": [...]}}``.
        The three lists are index-aligned, in input order.

    Raises:
        KeyError: if a detection references an image id that is not in
            ``imgid_info``.
    """
    with open(json_path, 'r') as f:
        bbox_data = json.load(f)

    if imgid_info is None:
        imgid_info = get_image_info('./image_info_val.json')

    result = {}

    # Group detections by image id.
    for ann in bbox_data:
        img_id = ann['image_id']
        W, H = imgid_info[img_id]
        bbox = ann['bbox']
        entry = result.setdefault(
            img_id, {"boxes": [], "scores": [], "labels": []})
        entry["boxes"].append([bbox[0] / W, bbox[1] / H,
                               (bbox[0] + bbox[2]) / W,
                               (bbox[1] + bbox[3]) / H])
        entry["scores"].append(ann['score'])
        entry["labels"].append(ann['category_id'])

    return result



if __name__ == '__main__':
    # Four per-model result files, ./bbox_0.json .. ./bbox_3.json.
    json_paths = ['./bbox_%d.json' % k for k in range(len('0123'))]
    # Equal score for every model, so every fusion weight is 1.0.
    model_scores = [1] * len(json_paths)
    print('start Soft_NMS.....')
    wbf(json_paths, model_scores)